{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ke Chen\n",
    "# knutchen@ucsd.edu\n",
    "# HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION\n",
    "# Compute the macro and per-class F1-scores on DESED for each system's detection output, using psds_eval\n",
    "import os.path as osp\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.patches import Rectangle, Patch\n",
    "from psds_eval import PSDSEval, plot_psd_roc, plot_per_class_psd_roc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Both evaluation inputs are tab-separated files under heatmap_output/.\n",
    "tsv_kwargs = dict(sep=\"\\t\")\n",
    "groundtruth = pd.read_csv(\"heatmap_output/answer.tsv\", **tsv_kwargs)\n",
    "metadata = pd.read_csv(\"heatmap_output/answer_meta.tsv\", **tsv_kwargs)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the evaluator once from the ground-truth and metadata tables;\n",
    "# every scoring cell in this notebook reuses this single instance.\n",
    "psds_eval = PSDSEval(\n",
    "    ground_truth=groundtruth,\n",
    "    metadata=metadata,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "macro F-score: 35.08\n",
      "  Alarm_bell_ringing: 34.33\n",
      "  Blender: 42.35\n",
      "  Cat: 36.31\n",
      "  Dishes: 17.60\n",
      "  Dog: 35.82\n",
      "  Electric_shaver_toothbrush: 23.81\n",
      "  Frying: 9.30\n",
      "  Running_water: 30.58\n",
      "  Speech: 69.68\n",
      "  Vacuum_cleaner: 51.01\n"
     ]
    }
   ],
   "source": [
    "# Score the detections stored in pann_desed_outputmap.tsv.\n",
    "detections = pd.read_csv(\n",
    "    \"heatmap_output/pann_desed_outputmap.tsv\",\n",
    "    sep=\"\\t\",\n",
    ")\n",
    "# The evaluator returns a pair: the macro F-score and a per-class mapping of F-scores.\n",
    "f_scores = psds_eval.compute_macro_f_score(detections)\n",
    "macro_f, class_f = f_scores\n",
    "\n",
    "# Report every score as a percentage with two decimals.\n",
    "print(f\"macro F-score: {macro_f*100:.2f}\")\n",
    "for clsname, f in class_f.items():\n",
    "    print(f\"  {clsname}: {f*100:.2f}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "macro F-score: 48.42\n",
      "  Alarm_bell_ringing: 48.56\n",
      "  Blender: 52.94\n",
      "  Cat: 67.67\n",
      "  Dishes: 24.96\n",
      "  Dog: 47.99\n",
      "  Electric_shaver_toothbrush: 42.92\n",
      "  Frying: 60.26\n",
      "  Running_water: 43.03\n",
      "  Speech: 46.78\n",
      "  Vacuum_cleaner: 49.12\n"
     ]
    }
   ],
   "source": [
    "# Score the detections stored in htsat-test_outputmap.tsv.\n",
    "detections = pd.read_csv(\n",
    "    \"heatmap_output/htsat-test_outputmap.tsv\",\n",
    "    sep=\"\\t\",\n",
    ")\n",
    "# The evaluator returns a pair: the macro F-score and a per-class mapping of F-scores.\n",
    "f_scores = psds_eval.compute_macro_f_score(detections)\n",
    "macro_f, class_f = f_scores\n",
    "\n",
    "# Report every score as a percentage with two decimals.\n",
    "print(f\"macro F-score: {macro_f*100:.2f}\")\n",
    "for clsname, f in class_f.items():\n",
    "    print(f\"  {clsname}: {f*100:.2f}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "macro F-score: 50.67\n",
      "  Alarm_bell_ringing: 47.53\n",
      "  Blender: 55.06\n",
      "  Cat: 72.39\n",
      "  Dishes: 30.90\n",
      "  Dog: 49.67\n",
      "  Electric_shaver_toothbrush: 41.86\n",
      "  Frying: 63.16\n",
      "  Running_water: 44.27\n",
      "  Speech: 51.26\n",
      "  Vacuum_cleaner: 50.63\n"
     ]
    }
   ],
   "source": [
    "# Score the detections stored in htsat-test-ensemble_outputmap.tsv.\n",
    "detections = pd.read_csv(\n",
    "    \"heatmap_output/htsat-test-ensemble_outputmap.tsv\",\n",
    "    sep=\"\\t\",\n",
    ")\n",
    "# The evaluator returns a pair: the macro F-score and a per-class mapping of F-scores.\n",
    "f_scores = psds_eval.compute_macro_f_score(detections)\n",
    "macro_f, class_f = f_scores\n",
    "\n",
    "# Report every score as a percentage with two decimals.\n",
    "print(f\"macro F-score: {macro_f*100:.2f}\")\n",
    "for clsname, f in class_f.items():\n",
    "    print(f\"  {clsname}: {f*100:.2f}\")\n"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "4050b233ab8e04dc04d998d60c703de3208e915a6def4e02bc2ea8bbb172af4e"
  },
  "kernelspec": {
   "display_name": "Python 3.7.9 64-bit ('kechen_py37': conda)",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  },
  "metadata": {
   "interpreter": {
    "hash": "98b0a9b7b4eaaa670588a142fd0a9b87eaafe866f1db4228be72b4211d12040f"
   }
  },
  "orig_nbformat": 2
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
